import base64
import dataclasses
import datetime as dt
import logging
from enum import Enum
from pathlib import PurePath
from types import GeneratorType
from typing import Any, Callable, Optional, Set, Tuple, Type

import pydantic

import opik.rest_api.core.datetime_utils as datetime_utils

try:
    import numpy as np
except ImportError:
    np = None

# Module-level logger. encode() uses it at DEBUG level to report cyclic
# references and objects that failed to serialize.
LOGGER = logging.getLogger(__name__)

# Registry of custom (type, encoder) pairs filled by register_encoder_extension().
# encode() checks it with isinstance() before applying any built-in rule.
# It is a set, so iteration order among registered pairs is unspecified.
_ENCODER_EXTENSIONS: Set[Tuple[Type, Callable[[Any], Any]]] = set()


def register_encoder_extension(obj_type: Type, encoder: Callable[[Any], Any]) -> None:
    """
    Register a custom encoder for instances of ``obj_type``.

    The ``(obj_type, encoder)`` pair is stored in the module-level registry
    that ``encode`` checks with ``isinstance`` before its built-in rules; the
    value returned by ``encoder`` is itself passed through ``encode``.
    The registry is a set, so registering the same pair again has no effect.

    Args:
        obj_type: Type whose instances should be handled by ``encoder``.
        encoder: Callable that receives the object and returns a replacement
            value to be encoded in its place.
    """
    extension = (obj_type, encoder)
    _ENCODER_EXTENSIONS.add(extension)


def encode(obj: Any, seen: Optional[Set[int]] = None) -> Any:
    """
    Convert ``obj`` into a structure built only from JSON-friendly values.

    This is a modified version of the serializer generated by Fern in
    rest_api.core.jsonable_encoder. The code is simplified to serialize complex
    objects into a textual representation. It also handles cyclic references
    to avoid infinite recursion.

    Args:
        obj: The value to encode.
        seen: ids of the objects that are currently being encoded higher up
            in the recursion. Callers normally leave it as ``None``; it is
            threaded through the recursive calls.

    Returns:
        ``str``/``int``/``float``/``None`` values unchanged; a dict with
        encoded keys and values for dicts, dataclass instances and pydantic
        models; a list with encoded items for lists, tuples, sets, frozensets,
        generators and numpy arrays; a string for everything else.
        An object reached again while it is still being encoded is replaced
        by a ``"<Cyclic reference to ...>"`` placeholder. Any exception raised
        while encoding is logged at DEBUG level and ``str(obj)`` is returned
        instead.
    """
    if seen is None:
        seen = set()

    # Track everything that can take part in a reference cycle: objects that
    # carry attributes and the mutable built-in containers. Plain dicts and
    # lists have no ``__dict__``, so without the isinstance check a
    # self-referential container would recurse until RecursionError.
    obj_id = id(obj)
    tracked = hasattr(obj, "__dict__") or isinstance(obj, (dict, list))
    if tracked:
        if obj_id in seen:
            LOGGER.debug(
                "Found cyclic reference to %s id=%s", type(obj).__name__, obj_id
            )
            return f"<Cyclic reference to {type(obj).__name__} id={obj_id}>"
        seen.add(obj_id)

    try:
        # User-registered encoders take precedence over every built-in rule;
        # their output is encoded again so it may return arbitrary objects.
        for type_, encoder in _ENCODER_EXTENSIONS:
            if isinstance(obj, type_):
                return encode(encoder(obj), seen)

        # is_dataclass() is also true for dataclass *classes*; only instances
        # hold field values in ``__dict__``, classes fall through to str().
        if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
            return encode(obj.__dict__, seen)

        if isinstance(obj, pydantic.BaseModel):
            obj_dict = {**obj.__dict__}
            # Extra (undeclared) fields are kept separately by pydantic.
            extra = obj.__pydantic_extra__
            if isinstance(extra, dict):
                obj_dict.update(extra)
            return encode(obj_dict, seen)

        if isinstance(obj, Enum):
            return encode(obj.value, seen)
        if isinstance(obj, PurePath):
            return str(obj)
        if isinstance(obj, (str, int, float, type(None))):
            return obj
        # datetime must be tested before date: datetime is a subclass of date.
        if isinstance(obj, dt.datetime):
            return datetime_utils.serialize_datetime(obj)
        if isinstance(obj, dt.date):
            return str(obj)
        if isinstance(obj, bytes):
            return base64.b64encode(obj).decode("utf-8")

        if isinstance(obj, dict):
            encoded_dict = {}
            for key, value in obj.items():
                encoded_key = encode(key, seen)
                if isinstance(encoded_key, (dict, list)):
                    # Keys such as tuples or frozensets encode to unhashable
                    # containers; use their text form so that only the key,
                    # not the whole dict, degrades to a string.
                    encoded_key = str(key)
                encoded_dict[encoded_key] = encode(value, seen)
            return encoded_dict

        if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
            return [encode(item, seen) for item in obj]

        if np is not None and isinstance(obj, np.ndarray):
            return encode(obj.tolist(), seen)

        if _is_pydantic_iterator_validator(obj):
            return "<Pydantic ValidatorIterator serialization is not supported>"

    except Exception:
        # Best effort: never fail the caller, fall back to str(obj) below.
        LOGGER.debug("Failed to serialize object.", exc_info=True)

    finally:
        # Once done encoding this object, remove it from `seen`,
        # so the same object can appear again at a sibling branch.
        # The flag computed on entry is reused and discard() is tolerant of a
        # missing id, so cleanup itself can never raise.
        if tracked:
            seen.discard(obj_id)

    return str(obj)


def _is_pydantic_iterator_validator(obj: Any) -> bool:
    """
    Tell whether ``obj`` is pydantic's ``ValidatorIterator`` wrapper.

    The check is done by name: the object's class must be called
    "ValidatorIterator" and the module name reported by the object must
    contain "pydantic".
    """
    # ValidatorIterator is not defined in python code and is added to the
    # pydantic-core namespace during the runtime, hence the name-based check.
    # This class fully replaces the original generator object, so it is not
    # possible to extract any extra information from the object.
    class_name = obj.__class__.__name__
    if class_name != "ValidatorIterator":
        return False

    return "pydantic" in obj.__module__
